Skip to content

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Oct 8, 2025

Fixes #31396

@jakevdp jakevdp requested a review from hawkinsp October 8, 2025 22:22
@jakevdp jakevdp self-assigned this Oct 8, 2025
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Oct 9, 2025
@jakevdp
Copy link
Collaborator Author

jakevdp commented Oct 13, 2025

/gemini review

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes a dtype promotion issue in scalar indexing by ensuring that indices are promoted to a dtype suitable for the array being indexed, rather than just based on the index types themselves. The changes in jax/_src/numpy/indexing.py and jax/_src/checkify.py are robust and address the issue. A regression test is also added. I've found one potential issue where the out-of-bounds check might not respect the newly determined dtype, which could lead to incorrect error checking in some cases.

Comment on lines +609 to +610
index_dtype = lax_utils.int_dtype_for_shape(arr.shape, signed=True)
start_indices = [lax.convert_element_type(idx, index_dtype) for idx in start_indices]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This is a great change to ensure start_indices have a robustly determined dtype. However, the very next call to jnp_error._check_precondition_oob_dynamic_slice seems to counteract this by hardcoding int32 for its index checks.

This could lead to incorrect out-of-bounds checks if lax_utils.int_dtype_for_shape selected a larger type like int64 (e.g., for an array with large dimensions), as the indices would be downcast to int32 inside the check function, potentially causing overflow or incorrect comparisons.

To make this fix fully effective, _check_precondition_oob_dynamic_slice should probably be updated to respect the dtype of the incoming start_indices, similar to the changes made in checkify.py in this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Incorrect indexing with jnp.int8 and Python scalar

1 participant